import copy

from omegaconf import DictConfig, OmegaConf
import hydra
from hydra.utils import get_original_cwd, to_absolute_path, instantiate
import os
import logging
from src.pl_model.classification_model import ClassificationModel  #TODO fix this
from src.utils.config_utils import log_splits_info
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from src.pl_model.N_classification_model import N_ClassificationModel
import copy
import numpy as np


# Module-level logger used by the functions in this script.
log = logging.getLogger(__name__)


def _build_callbacks(cfg: DictConfig, client_idx: int) -> list:
    """Instantiate the callbacks listed under ``cfg.callbacks`` for one client.

    Callbacks are rebuilt for every client so that each client gets fresh
    callback state. ``EarlyStopping`` and ``ModelCheckpoint`` are pointed at
    the current client's validation-accuracy metric, and ``ModelCheckpoint``
    writes into a per-client directory.

    The per-client values are passed to ``instantiate`` as keyword overrides
    rather than assigned onto ``cfg``. This leaves the shared config untouched,
    and the keys (e.g. ``dirpath``) do not need to already exist in a
    struct-mode config.

    Args:
        cfg: Full run config; only ``cfg.callbacks`` is read.
        client_idx: Index of the client about to be trained.

    Returns:
        List of instantiated callback objects (empty if none are configured).
    """
    callbacks = []
    if "callbacks" not in cfg:
        return callbacks

    monitor = f"val_client-{client_idx}/acc"
    for _, cb_conf in cfg.callbacks.items():
        # Entries without a _target_ are not instantiable; skip them.
        if "_target_" not in cb_conf:
            continue
        log.info(f"Instantiating callback <{cb_conf._target_}>")
        overrides = {}
        if cb_conf._target_ == "pytorch_lightning.callbacks.EarlyStopping":
            overrides = {"monitor": monitor}
        elif cb_conf._target_ == "pytorch_lightning.callbacks.ModelCheckpoint":
            overrides = {
                "monitor": monitor,
                "dirpath": f"models/client-{client_idx}",
                "filename": "model",
            }
        callbacks.append(instantiate(cb_conf, **overrides))
    return callbacks


def _log_checkpoint_summary(logger, trainer: Trainer, client_idx: int) -> None:
    """Store the best validation score and best checkpoint path for one client.

    Values go into ``logger.experiment.summary`` (presumably a W&B run summary;
    confirm against ``cfg.logger``). If the trainer has no checkpoint callback
    (checkpointing disabled), a warning is logged instead of crashing.
    """
    ckpt_cb = trainer.checkpoint_callback
    if ckpt_cb is None:
        log.warning(
            "No checkpoint callback for client %d; skipping best-model summary",
            client_idx,
        )
        return
    summary = logger.experiment.summary
    summary[f"client-{client_idx}_best-val-acc"] = ckpt_cb.best_model_score
    summary[f"client-{client_idx}_best-model-path"] = ckpt_cb.best_model_path


@hydra.main(config_path="conf", config_name="train_config")
def my_app(cfg: DictConfig) -> None:
    """Train a classification model sequentially over all configured clients.

    For each client index in ``range(cfg.num_clients)`` this fits, validates and
    tests the same ``ClassificationModel`` instance on the datamodule's current
    client, records checkpoint info in the logger summary, then advances the
    datamodule to the next client.

    Args:
        cfg: Hydra-composed config (``conf/train_config``). ``cfg.split_function``
            and ``cfg.model.model.num_classes`` are modified in place.
    """
    print(OmegaConf.to_yaml(cfg))
    print(f"Current working directory : {os.getcwd()}")
    print(f"Orig working directory    : {get_original_cwd()}")
    print(f"to_absolute_path('foo')   : {to_absolute_path('foo')}")
    print(f"to_absolute_path('/foo')  : {to_absolute_path('/foo')}")

    if cfg.num_clients == 1:
        log.info("Overriding split function since there is only one client")
        cfg.split_function._target_ = "src.data.data_utils.no_split"

    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    pl.seed_everything(cfg.datamodule.seed)
    datamodule = instantiate(cfg.datamodule, _recursive_=False)
    # num_classes is only known after the datamodule exists; write it into the
    # model config before the model is built from cfg below.
    cfg.model.model.num_classes = datamodule.num_classes

    log.info(f"Instantiating logger <{cfg.logger._target_}>")
    logger = instantiate(cfg.logger)

    # One model instance is created here and reused for every client.
    # TODO: support cfg.same_initial (start each client from the same init).
    pl_model = ClassificationModel(cfg=cfg)

    for current_client_idx in range(cfg.num_clients):
        # Callbacks and trainer are rebuilt for every client.
        callbacks = _build_callbacks(cfg, current_client_idx)
        trainer: Trainer = instantiate(cfg.trainer, logger=logger, callbacks=callbacks)
        trainer.fit(model=pl_model, datamodule=datamodule)

        # NOTE(review): these run on the in-memory weights after fit(), not on
        # the best checkpoint; consider ckpt_path="best" once checkpoints are
        # guaranteed to exist.
        trainer.validate(model=pl_model, datamodule=datamodule)
        trainer.test(model=pl_model, datamodule=datamodule)

        _log_checkpoint_summary(logger, trainer, current_client_idx)

        # Log the split information only once, on the first client.
        # TODO: also log the label distribution of each split.
        if current_client_idx == 0:
            splits_info = log_splits_info(
                datamodule.datasets_train, datamodule.datasets_val, datamodule.fair_val
            )
            logger.experiment.summary["splits_info"] = convert_numpy_to_python(splits_info)

        log.debug(
            "len(pl_model.model.state_dict()) = %d", len(pl_model.model.state_dict())
        )

        # Advance to the next client's data, except after the last client,
        # where next_client() would trip its internal assert.
        if current_client_idx < cfg.num_clients - 1:
            datamodule.next_client()

    # TODO: Query-based evaluation? Queries could be generated automatically
    # (can live in this script) or supplied by the user (needs a separate script).
    # TODO: Take the saved models and their best accuracy, and log them in the
    # transfer experiment.



def convert_numpy_to_python(data):
    """Recursively convert NumPy values inside ``data`` into plain Python objects.

    Used before storing nested structures (such as the output of
    ``log_splits_info``) in the experiment summary. That summary presumably
    needs JSON-serialisable data; confirm against the logger in use.

    Handling by type:
        - dicts: both keys and values are converted (result is a plain ``dict``).
        - lists and plain tuples: elements are converted, and the container
          type is kept.
        - ``np.ndarray``: becomes a (nested) list.
        - NumPy scalars: ``np.bool_`` becomes ``bool``, ``np.integer`` becomes
          ``int``, ``np.floating`` becomes ``float``, and any other NumPy
          scalar goes through ``.item()``.
        - anything else: returned unchanged.

    Args:
        data: Arbitrary, possibly nested, value.

    Returns:
        An equivalent structure free of NumPy scalar/array types.
    """
    if isinstance(data, dict):
        # Keys are converted too: json.dumps rejects NumPy integer keys.
        return {
            convert_numpy_to_python(k): convert_numpy_to_python(v)
            for k, v in data.items()
        }
    if isinstance(data, list):
        return [convert_numpy_to_python(item) for item in data]
    # Only plain tuples: namedtuples are passed through unchanged, since their
    # constructors do not accept a single iterable.
    if type(data) is tuple:
        return tuple(convert_numpy_to_python(item) for item in data)
    # np.bool_ is not a subclass of np.integer, so it must be checked explicitly.
    if isinstance(data, np.bool_):
        return bool(data)
    if isinstance(data, np.integer):
        return int(data)
    if isinstance(data, np.floating):
        return float(data)
    if isinstance(data, np.ndarray):
        # tolist() yields Python scalars for numeric dtypes. Recurse so that
        # object-dtype arrays holding NumPy values are converted as well.
        return convert_numpy_to_python(data.tolist())
    if isinstance(data, np.generic):
        return data.item()
    return data

if __name__ == "__main__":
    # Script entry point. my_app is wrapped by @hydra.main, which composes
    # `cfg` from conf/train_config, so it is called without arguments.
    my_app()
